{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f3b3efd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# type: ignore"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OpenAI Supervised Fine-Tuning\n",
    "\n",
    "This recipe allows TensorZero users to fine-tune OpenAI models using their own data.\n",
    "Since TensorZero automatically logs all inferences and feedback, it is straightforward to fine-tune a model using your own data and any prompt you want.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get started:\n",
    "\n",
    "- Set the `TENSORZERO_CLICKHOUSE_URL` environment variable. For example: `TENSORZERO_CLICKHOUSE_URL=\"http://chuser:chpassword@localhost:8123/tensorzero\"`\n",
    "- Set the `OPENAI_API_KEY` environment variable.\n",
    "- Update the following parameters:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG_PATH = \"../../../examples/data-extraction-ner/config/tensorzero.toml\"\n",
    "\n",
    "FUNCTION_NAME = \"extract_entities\"\n",
    "\n",
    "METRIC_NAME = \"jaccard_similarity\"\n",
    "\n",
    "# The name of the variant to use to grab the templates used for fine-tuning\n",
    "TEMPLATE_VARIANT_NAME = \"gpt_4o_mini\"\n",
    "\n",
    "# If the metric is a float metric, you can set the threshold to filter the data\n",
    "FLOAT_METRIC_THRESHOLD = 0.5\n",
    "\n",
    "# Fraction of the data to use for validation\n",
    "VAL_FRACTION = 0.2\n",
    "\n",
    "# Maximum number of samples to use for fine-tuning\n",
    "MAX_SAMPLES = 100_000\n",
    "\n",
    "# The name of the model to fine-tune (supported models: https://platform.openai.com/docs/guides/fine-tuning)\n",
    "MODEL_NAME = \"gpt-4o-mini-2024-07-18\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "tensorzero_path = os.path.abspath(os.path.join(os.getcwd(), \"../../../\"))\n",
    "if tensorzero_path not in sys.path:\n",
    "    sys.path.append(tensorzero_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import toml\n",
    "from IPython.display import clear_output\n",
    "from tensorzero import (\n",
    "    FloatMetricFilter,\n",
    "    OpenAISFTConfig,\n",
    "    OptimizationJobStatus,\n",
    "    TensorZeroGateway,\n",
    ")\n",
    "\n",
    "from recipes.util import train_val_split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the TensorZero client\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorzero_client = TensorZeroGateway.build_embedded(\n",
    "    config_file=CONFIG_PATH,\n",
    "    clickhouse_url=os.environ[\"TENSORZERO_CLICKHOUSE_URL\"],\n",
    "    timeout=15,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the metric filter as needed\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "comparison_operator = \">=\"\n",
    "metric_node = FloatMetricFilter(\n",
    "    metric_name=METRIC_NAME,\n",
    "    value=FLOAT_METRIC_THRESHOLD,\n",
    "    comparison_operator=comparison_operator,\n",
    ")\n",
    "\n",
    "# from tensorzero import BooleanMetricFilter\n",
    "\n",
    "# metric_node = BooleanMetricFilter(\n",
    "#     metric_name=METRIC_NAME,\n",
    "#     value=True  # or False\n",
    "# )\n",
    "\n",
    "metric_node"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Query the inferences from ClickHouse\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stored_inferences = tensorzero_client.experimental_list_inferences(\n",
    "    function_name=FUNCTION_NAME,\n",
    "    variant_name=None,\n",
    "    output_source=\"inference\",  # could also be \"demonstration\"\n",
    "    filters=metric_node,\n",
    "    limit=MAX_SAMPLES,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Render the inputs using the templates.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rendered_samples = tensorzero_client.experimental_render_samples(\n",
    "    stored_inferences=stored_inferences,\n",
    "    variants={FUNCTION_NAME: TEMPLATE_VARIANT_NAME},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split the data into training and validation sets for fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_samples, val_samples = train_val_split(\n",
    "    rendered_samples,\n",
    "    val_size=VAL_FRACTION,\n",
    "    last_inference_only=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Launch the fine tuning job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimization_config = OpenAISFTConfig(\n",
    "    model=MODEL_NAME,\n",
    ")\n",
    "\n",
    "job_handle = tensorzero_client.experimental_launch_optimization(\n",
    "    train_samples=train_samples,\n",
    "    val_samples=val_samples,\n",
    "    optimization_config=optimization_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Wait for the fine-tuning job to complete.\n",
    "\n",
    "This cell will take a while to run.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "while True:\n",
    "    clear_output(wait=True)\n",
    "\n",
    "    try:\n",
    "        job_info = tensorzero_client.experimental_poll_optimization(job_handle=job_handle)\n",
    "        print(job_info)\n",
    "        if job_info.status in (\n",
    "            OptimizationJobStatus.Completed,\n",
    "            OptimizationJobStatus.Failed,\n",
    "        ):\n",
    "            break\n",
    "    except Exception as e:\n",
    "        print(f\"Error: {e}\")\n",
    "\n",
    "    time.sleep(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the fine-tuning job is complete, you can add the fine-tuned model to your config file.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fine_tuned_model = job_info.output[\"routing\"][0]\n",
    "model_config = {\n",
    "    \"models\": {\n",
    "        fine_tuned_model: {\n",
    "            \"routing\": [\"openai\"],\n",
    "            \"providers\": {\"openai\": {\"type\": \"openai\", \"model_name\": fine_tuned_model}},\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "print(toml.dumps(model_config))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, add a new variant to your function to use the fine-tuned model.\n"
   ]
  },
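  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to produce that variant block is to print it as TOML, in the same way as the model block above. The cell below is a minimal sketch, not a definitive implementation: `build_variant_config` and the variant name `fine_tuned` are hypothetical, and it assumes a `chat_completion` variant that starts with a weight of 0. You will likely also need to copy the template paths (for example `system_template`) from the `TEMPLATE_VARIANT_NAME` variant into the new variant.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import toml\n",
    "\n",
    "\n",
    "def build_variant_config(function_name, variant_name, model_name, weight=0.0):\n",
    "    # Hypothetical helper: nests one chat_completion variant under its function\n",
    "    return {\n",
    "        \"functions\": {\n",
    "            function_name: {\n",
    "                \"variants\": {\n",
    "                    variant_name: {\n",
    "                        \"type\": \"chat_completion\",\n",
    "                        \"weight\": weight,\n",
    "                        \"model\": model_name,\n",
    "                    }\n",
    "                }\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "\n",
    "\n",
    "# Reuse values from earlier cells when available; otherwise fall back to placeholders\n",
    "function_name = globals().get(\"FUNCTION_NAME\", \"extract_entities\")\n",
    "model_name = globals().get(\"fine_tuned_model\", \"my_fine_tuned_model\")\n",
    "\n",
    "# \"fine_tuned\" is a placeholder variant name; the weight of 0 is an assumed starting point\n",
    "variant_config = build_variant_config(function_name, \"fine_tuned\", model_name)\n",
    "print(toml.dumps(variant_config))"
   ]
  },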
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You're all set!\n",
    "\n",
    "You can change the weight to enable a gradual rollout of the new model.\n",
    "\n",
    "You might also add other parameters (e.g. `temperature`) to the variant section in the config file.\n"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "ipynb,py:percent",
   "main_language": "python"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
